import glob
import pymysql
import pandas as pd
import argparse
from rich.console import Console
from rich.table import Table
import cfg
import os

# Absolute path of the directory that contains this script file.
base_dir = os.path.split(os.path.abspath(__file__))[0]

def check_file_extension(value):
    """Validate that a path ends in ``.csv`` or ``.xlsx`` (case-insensitive).

    Intended for use as an argparse ``type=`` callable.

    Returns:
        ``value`` unchanged when its extension is accepted.

    Raises:
        argparse.ArgumentTypeError: if the (lower-cased) extension is
            anything other than ``.csv`` or ``.xlsx``.
    """
    _, suffix = os.path.splitext(value)
    ext = suffix.lower()  # compare extensions case-insensitively
    if ext in ('.csv', '.xlsx'):
        return value
    raise argparse.ArgumentTypeError(f"File must have a .csv or .xlsx extension, but got '{ext}'")
    
# Command-line interface: both options are mandatory, and --output_path is
# validated by check_file_extension before the script proceeds.
parser = argparse.ArgumentParser(
    description="Process Excel files and update with database information.",
)
parser.add_argument(
    '--input_csv',
    required=True,
    help='Path to input .xlsx file or directory containing .xlsx files.',
)
parser.add_argument(
    '--output_path',
    required=True,
    type=check_file_extension,
    help='Path to save the updated file (must be .csv or .xlsx).',
)
args = parser.parse_args()

# Open the database connection from the settings in cfg.db and create the
# cursor that the rest of the script uses for its queries.
db = pymysql.connect(
    **{key: cfg.db[key] for key in ("host", "port", "user", "password", "database")}
)
cursor = db.cursor()

# Resolve --input_csv into the list of files to process. An existing
# directory (or any path written with a trailing separator) expands to the
# .xlsx files directly inside it; anything else is treated as a single file.
# Previously only a trailing '/' marked a directory, so passing "some_dir"
# made the script try to read the directory itself as a file.
if os.path.isdir(args.input_csv) or args.input_csv.endswith(('/', os.sep)):
    # sorted() makes the processing order deterministic across runs.
    file_paths = sorted(glob.glob(os.path.join(args.input_csv, "*.xlsx")))
else:
    file_paths = [args.input_csv]

# Rich console used to print the summary table for each processed file.
console = Console()

# Prefix prepended to record_file_name to build the recording URL.
RECORD_URL_PREFIX = "http://116.62.120.233"
# Maximum number of values bound into a single "IN (...)" clause.
IN_CLAUSE_CHUNK_SIZE = 1000


def _format_begintime(raw_id):
    """Format the first 14 characters of an ID as 'YYYY-MM-DD HH:MM:SS'.

    The ID prefix is sliced positionally (YYYYMMDDHHMMSS); no validation is
    performed, so a malformed prefix simply yields a string that will not
    match any database row.
    """
    prefix = raw_id[:14]
    return f"{prefix[:4]}-{prefix[4:6]}-{prefix[6:8]} {prefix[8:10]}:{prefix[10:12]}:{prefix[12:]}"


def _read_input_table(path):
    """Load one input file, picking the pandas reader from its extension."""
    ext = os.path.splitext(path)[1].lower()
    if ext in ('.xlsx', '.xls'):
        return pd.read_excel(path)
    return pd.read_csv(path)


def _fetch_where_in(db_cursor, sql_template, values):
    """Run an ``IN (...)`` query with bound parameters, in chunks.

    Args:
        db_cursor: DB-API cursor used to execute the statements.
        sql_template: SQL text containing one ``{placeholders}`` field where
            the comma-separated ``%s`` markers are inserted.
        values: list of values to match against.

    Returns:
        A list with the rows of every chunk. When ``values`` is empty no
        statement is executed (``IN ()`` is invalid SQL) and the list is empty.
    """
    rows = []
    for start in range(0, len(values), IN_CLAUSE_CHUNK_SIZE):
        chunk = values[start:start + IN_CLAUSE_CHUNK_SIZE]
        placeholders = ','.join(['%s'] * len(chunk))
        # Values are passed separately so the driver escapes them.
        db_cursor.execute(sql_template.format(placeholders=placeholders), chunk)
        rows.extend(db_cursor.fetchall())
    return rows


def _output_path_for(input_path, output_path, multiple_inputs):
    """Return where the result for ``input_path`` should be written.

    With a single input the requested ``output_path`` is used unchanged.
    With several inputs the input file's stem is appended to the output stem
    so that each file gets its own result instead of overwriting the last one.
    """
    if not multiple_inputs:
        return output_path
    root, ext = os.path.splitext(output_path)
    stem = os.path.splitext(os.path.basename(input_path))[0]
    return f"{root}_{stem}{ext}"


try:
    multiple_inputs = len(file_paths) > 1

    for file_path in file_paths:
        df = _read_input_table(file_path)
        print("Excel 文件加载完成，共有行数:", len(df))
        df['ID'] = df['ID'].astype(str)
        for new_col in ('caller_num', 'callee_num', 'url'):
            df[new_col] = None

        # One formatted begintime per row, index-aligned with df.
        formatted_times = df['ID'].map(_format_begintime)

        # Step 1: begintime -> (customer_uuid, record_file_name).
        results1 = _fetch_where_in(
            cursor,
            "SELECT begintime, customer_uuid, record_file_name "
            "FROM cti_record WHERE begintime IN ({placeholders});",
            list(formatted_times.unique()),
        )
        # If several records share a begintime, the last fetched row wins.
        begintime_map = {str(row[0]): (row[1], row[2]) for row in results1}
        # De-duplicate while keeping first-seen order.
        customer_uuids = list(dict.fromkeys(row[1] for row in results1 if row[1] is not None))

        # Step 2: customer_uuid -> (caller_num, callee_num).
        if customer_uuids:
            results2 = _fetch_where_in(
                cursor,
                "SELECT call_uuid, caller_num, callee_num "
                "FROM cti_cdr_call WHERE call_uuid IN ({placeholders});",
                customer_uuids,
            )
            customer_uuid_map = {row[0]: (row[1], row[2]) for row in results2}
        else:
            customer_uuid_map = {}
            print("未找到任何有效的 customer_uuid，跳过相关查询。")

        # Step 3: fill the new columns row by row, reporting every miss.
        for index, formatted_begintime in formatted_times.items():
            if formatted_begintime not in begintime_map:
                print(f"未找到 begintime {formatted_begintime} 对应的 customer_uuid 和 record_file_name。")
                continue
            customer_uuid, record_file_name = begintime_map[formatted_begintime]
            if customer_uuid not in customer_uuid_map:
                print(f"未找到 customer_uuid {customer_uuid} 对应的 caller_num 和 callee_num。")
                continue
            caller_num, callee_num = customer_uuid_map[customer_uuid]
            df.at[index, 'caller_num'] = caller_num
            df.at[index, 'callee_num'] = callee_num
            # Skip the URL when there is no file name, rather than emitting
            # a link that ends in the literal text "None".
            if record_file_name:
                df.at[index, 'url'] = f"{RECORD_URL_PREFIX}{record_file_name}"

        # Step 4: write the enriched table. The extension is compared
        # case-insensitively, matching the --output_path validator.
        output_file_path = _output_path_for(file_path, args.output_path, multiple_inputs)
        output_ext = os.path.splitext(output_file_path)[1].lower()
        if output_ext == '.xlsx':
            df.to_excel(output_file_path, index=False)
        elif output_ext == '.csv':
            df.to_csv(output_file_path, index=False)

        # Step 5: print value distributions for the columns whose name
        # contains '关键' but neither 'ID' nor '原因'. str() guards against
        # non-string column labels, which spreadsheets can produce.
        filtered_columns = [
            col for col in df.columns
            if 'ID' not in str(col) and '原因' not in str(col) and '关键' in str(col)
        ]
        main_table = Table(title="Evaluation Results", show_header=True)
        main_table.add_column("Category", justify="left", style="cyan", width=20)
        main_table.add_column("Value", justify="left", style="cyan", width=45)
        main_table.add_column("Count", justify="right", style="magenta", width=10)
        main_table.add_column("Percentage", justify="right", style="green", width=15)

        for col in filtered_columns:
            value_counts = df[col].value_counts()
            total = value_counts.sum()
            for i, (value, count) in enumerate(value_counts.items()):
                percentage = '{:.2f}%'.format(count / total * 100)
                # Only the first row of each column group shows the name.
                category = col if i == 0 else ""
                main_table.add_row(str(category), str(value), str(count), percentage)

        console.print(main_table)
finally:
    # Always release the database resources, even if processing failed.
    cursor.close()
    db.close()
    print("数据库连接已关闭")
